import copy
import math
import numpy as np
import torch
from model.segnet import SegNet
import cv2
import time
import torch.nn as nn
import os

# Overlay colour for each mask label produced by demo():
# label 0 (at or below the threshold) -> black, label 1 (above it) -> green.
colormap = [
    [0, 0, 0],
    [0, 255, 0],
]
# Array form of colormap so a label image can be coloured by fancy indexing.
ndcolormap = np.asarray(colormap)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = (
    torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
)
model_path = r'weight/segnet_epoch_500.pth'
# Resize target handed to cv2.resize, which reads it as (width, height).
shape = (256, 256)


def demo(net, img_path, *, threshold=0.05):
    """Segment one image with ``net`` and return the mask blended over it.

    The image is read with OpenCV, converted to RGB, resized to the
    module-level ``shape`` and scaled to [0, 1]. It is then fed to ``net`` on
    ``device`` as a batch of one. A sigmoid is applied to the network output.
    Pixels whose probability exceeds ``threshold`` get label 1 and all others
    get label 0. Each label is coloured through ``ndcolormap``, and the
    coloured mask is blended 50/50 with the resized input.

    Args:
        net: segmentation model, called as ``net(batch)``.
            NOTE(review): assumed to emit a single-channel logit map of shape
            (1, 1, H, W), as ``SegNet(3, 1)`` suggests. Confirm before using
            it with multi-class models.
        img_path: path of the image file to segment.
        threshold: probability above which a pixel is labelled 1. The default
            of 0.05 is the value that was previously hard-coded.

    Returns:
        uint8 array of shape (H, W, 3) in RGB channel order. Convert it to BGR
        before writing it with ``cv2.imwrite``.

    Raises:
        OSError: if ``cv2.imread`` cannot read ``img_path``.
    """
    bgr = cv2.imread(img_path)
    if bgr is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with a clear message instead of inside cvtColor.
        raise OSError(f'could not read image: {img_path!r}')
    # cv2.resize takes dsize as (width, height).
    rgb = cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), shape)

    # HWC uint8 -> CHW float32 in [0, 1] -> NCHW batch of one on `device`.
    # astype() makes a copy, so `rgb` itself is left untouched for drawing.
    batch = torch.from_numpy(rgb.astype(np.float32).transpose((2, 0, 1))) / 255.
    batch = batch.to(device, dtype=torch.float32).unsqueeze(0)

    net.eval()
    with torch.no_grad():
        probs = torch.sigmoid(net(batch)).squeeze().cpu().numpy()

    # Explicit boolean threshold -> 0/1 labels, then one RGB colour per label.
    labels = (probs > threshold).astype(np.uint8)
    mask = ndcolormap[labels].astype(np.uint8)

    return cv2.addWeighted(rgb, 0.5, mask, 0.5, 0)

if __name__ == '__main__':
    # Load the checkpoint once. It may be either a state_dict or a whole
    # pickled model; the old code built a SegNet and immediately discarded it.
    # NOTE(review): torch.load unpickles arbitrary objects, so only load
    # trusted files. On PyTorch versions where weights_only defaults to True,
    # loading a whole pickled model requires weights_only=False.
    checkpoint = torch.load(model_path, map_location=device)
    if isinstance(checkpoint, dict):
        net = SegNet(3, 1)
        net.load_state_dict(checkpoint)
    else:
        net = checkpoint
    net.to(device)

    path = r'train_data/test_Crack'
    save = r'output'
    # cv2.imwrite does not create missing directories; it just returns False.
    os.makedirs(save, exist_ok=True)

    for img_name in sorted(os.listdir(path)):
        img_path = os.path.join(path, img_name)
        out_path = os.path.join(save, img_name)
        overlay = demo(net, img_path)
        # demo() returns an RGB image, but cv2.imwrite expects BGR channel order.
        if not cv2.imwrite(out_path, cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR)):
            print(f'failed to write {out_path}')